from __future__ import print_function
import argparse
import os
#import cPickle as pickle
import pickle
import random
import numpy as np
import csv

import torch
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable

from model import *
from dataset import TriangleDataset

# Evaluation settings
parser = argparse.ArgumentParser(description='PyTorch Image Classification Transformer')
# The script cannot run without a run directory (the checkpoint path is built
# from it below), so fail early with an argparse error instead of a crash.
parser.add_argument('--path', type=str, required=True,
                    help='run directory containing checkpoints/model<epoch>.pt')
parser.add_argument('--epoch', type=int, default=0,
                    help='checkpoint epoch to load; 0 loads checkpoints/model.pt')
args = parser.parse_args()

# Turn the epoch number into the checkpoint filename suffix used when loading:
# 0 -> 'model.pt', N -> 'model_N.pt'.
if args.epoch == 0:
    args.epoch = ''
else:
    args.epoch = '_'+str(args.epoch)

# The hyperparameters below must match the checkpoint being evaluated. They
# were previously parsed from the run directory name as follows (kept for
# reference; the parsing itself is disabled because it is unused and crashes
# on paths without a '/'):
# name = args.path.split('/')[1].split('_')
# args.model = name[2]
# args.dataset = name[1]
# args.iterations = int(name[3])
# args.transformer_dim = int(name[4])
# args.n_heads = int(name[5])
# args.n_rules = int(name[6])
# args.qk_dim = int(name[7])
# args.seed = int(name[-2])
# args.dot = True if "dot" in name else False
# args.no_cuda = False
# args.batch_size = 64
# args.gumbel = False

args.model = 'Compositional'
args.dataset = 'Triangles'
args.iterations = 4
args.transformer_dim = 256
args.n_heads = 4
args.n_rules = 4
args.qk_dim = 32
args.seed = 7
args.dot = True
args.no_cuda = False
args.batch_size = 64
args.gumbel = False

# Only use CUDA when it was not disabled AND a device is actually present;
# otherwise moving the model to 'cuda' later would fail on CPU-only machines.
args.cuda = not args.no_cuda and torch.cuda.is_available()

print(args)

def set_seed(seed):
    """Seed every RNG used by this script and make cuDNN deterministic.

    Args:
        seed: integer seed applied to Python's ``random``, NumPy's global RNG,
            and torch (CPU and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU; torch silently ignores this when CUDA is
    # unavailable, so no availability check (or global `args`) is needed.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(args.seed)

# Build train/test loaders for the selected dataset. Only `testloader` is used
# by this script; the training split is still constructed, in the original
# order, so any RNG it consumes (e.g. if a dataset samples at construction
# time) is consumed before the test split is built.
if args.dataset =="CIFAR10":
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                             shuffle=False, num_workers=2)
elif args.dataset =="CIFAR100":
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize((0.5070, 0.4865, 0.4409), (0.2673, 0.2564, 0.2761)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5070, 0.4865, 0.4409), (0.2673, 0.2564, 0.2761)),
    ])

    trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                            download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                           download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                             shuffle=False, num_workers=2)
elif args.dataset == 'Triangles':
    # NOTE(review): TriangleDataset comes from dataset.py; if it generates its
    # examples at construction time, the test split depends on the train split
    # being built first under the same seed — keep this order.
    train_dataset = TriangleDataset(num_examples = 50000)
    test_dataset = TriangleDataset(num_examples = 10000)
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size = args.batch_size, num_workers = 2, shuffle = False)
    testloader = torch.utils.data.DataLoader(test_dataset, batch_size = args.batch_size, num_workers = 2, shuffle = False)
else:
    # Fail here with a clear message instead of a NameError on `testloader`
    # much later in the script.
    raise ValueError(f"Unsupported dataset: {args.dataset!r}")

device = 'cuda' if args.cuda else 'cpu'

print("Loading Model")
net = Model(args)
# map_location remaps tensors saved from a GPU run onto the current device,
# so the same checkpoint also loads on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints from trusted sources.
checkpoint_path = os.path.join(args.path, f'checkpoints/model{args.epoch}.pt')
state_dict = torch.load(checkpoint_path, map_location=device)
net.load_state_dict(state_dict)
net = net.to(device)
print(net)

# Triangles is scored as a binary task (sigmoid + 0.5 threshold in test()),
# so it uses BCE on logits; the other datasets are multi-class.
if args.dataset == "Triangles":
    criterion = torch.nn.BCEWithLogitsLoss()
else:
    criterion = torch.nn.CrossEntropyLoss()

def test():
    """Evaluate ``net`` on ``testloader`` and collect the model's activations.

    Uses the module-level ``net``, ``testloader``, ``criterion``, ``device``
    and ``args`` defined above.

    Returns:
        A tuple ``(loss, accuracy, acts)``:
        ``loss`` is the mean per-example loss over the test set;
        ``accuracy`` is in percent;
        ``acts`` is the second output of ``net(inputs, True)``, concatenated
        along dim 0 across all batches and kept on the CPU.

    Raises:
        ValueError: if ``testloader`` yields no examples.
    """
    net.eval()

    correct = 0.
    total_loss = 0.
    total = 0.

    acts = []

    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            # NOTE(review): the second positional argument appears to request
            # the extra activation output; confirm against model.py.
            outputs, act = net(inputs, True)
            # Move activations off the accelerator right away so device
            # memory does not grow with the size of the test set.
            acts.append(act.cpu())
            if args.dataset == 'Triangles':
                targets = targets.float()
                # Shape the logits like the targets explicitly: a bare
                # squeeze() turns a batch of one into a 0-d tensor, which
                # no longer matches the targets in the BCE loss.
                outputs = outputs.reshape(targets.shape)
                prediction = (torch.sigmoid(outputs) >= 0.5).int()
            else:
                _, prediction = outputs.max(dim = 1)

            loss = criterion(outputs, targets)

            # criterion averages over the batch; re-weight by batch size so the
            # final value is a true per-example mean even with a short last batch.
            total_loss += loss.item() * targets.size(0)
            total += targets.size(0)
            correct += torch.eq(prediction, targets).sum().item()

    if total == 0:
        raise ValueError("testloader yielded no examples")

    loss = total_loss / total
    accuracy = correct / total

    return loss, accuracy * 100, torch.cat(acts)


# Run a single evaluation pass over the test split and report the metrics.
test_loss, test_acc, acts = test()
print(f"Loss: {test_loss:.3f}  |  Accuracy: {test_acc:.2f}")
print()

# Arrange the activations as (-1, n_heads, n_rules) and convert to NumPy;
# presumably the leading axis is one entry per test example — TODO confirm.
acts = acts.reshape(-1, args.n_heads, args.n_rules).detach().cpu().numpy()

# 1-based head labels, each repeated acts.shape[0] times back to back:
# [1]*N + [2]*N + ... (np.repeat semantics).
# NOTE(review): rows of acts.reshape(-1, args.n_rules) cycle the head index
# fastest, which would need np.tile instead — confirm the intended pairing
# before building a table from `searches` and `acts`.
searches = np.repeat(np.arange(1, args.n_heads + 1), acts.shape[0])

print(searches)
print(searches.shape)
print(acts.shape)
# df = pd.DataFrame(acts, columns=[""])
# print(acts.shape)